{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "47020873",
   "metadata": {},
   "outputs": [],
   "source": [
    "docs = [\n",
    "    #数字\n",
    "    ['5', '2', '4', '8', '6', '2', '3', '6', '4'],\n",
    "    ['4', '8', '5', '6', '9', '5', '5', '6'],\n",
    "    ['1', '1', '5', '2', '3', '3', '8'],\n",
    "    ['3', '6', '9', '6', '8', '7', '4', '6', '3'],\n",
    "    ['8', '9', '9', '6', '1', '4', '3', '4'],\n",
    "    ['1', '0', '2', '0', '2', '1', '3', '3', '3', '3', '3'],\n",
    "    ['9', '3', '3', '0', '1', '4', '7', '8'],\n",
    "    ['9', '9', '8', '5', '6', '7', '1', '2', '3', '0', '1', '0'],\n",
    "\n",
    "    #字母中夹杂了一些数字\n",
    "    ['a', 't', 'g', 'q', 'e', 'h', '9', 'u', 'f'],\n",
    "    ['e', 'q', 'y', 'u', 'o', 'i', 'p', 's'],\n",
    "    ['q', 'o', '9', 'p', 'l', 'k', 'j', 'o', 'k', 'k', 'o', 'p'],\n",
    "    ['h', 'g', 'y', 'i', 'u', 't', 't', 'a', 'e', 'q'],\n",
    "    ['i', 'k', 'd', 'q', 'r', 'e', '9', 'e', 'a', 'd'],\n",
    "    ['o', 'p', 'd', 'g', '9', 's', 'a', 'f', 'g', 'a'],\n",
    "    ['i', 'u', 'y', 'g', 'h', 'k', 'l', 'a', 's', 'w'],\n",
    "    ['o', 'l', 'u', 'y', 'a', 'o', 'g', 'f', 's'],\n",
    "    ['o', 'p', 'i', 'u', 'y', 'g', 'd', 'a', 's', 'j', 'd', 'l'],\n",
    "    ['u', 'k', 'i', 'l', 'o', '9', 'l', 'j', 's'],\n",
    "    ['y', 'g', 'i', 's', 'h', 'k', 'j', 'l', 'f', 'r', 'f'],\n",
    "    ['i', 'o', 'h', '9', 'n', '9', 'd', '9', 'f', 'a', '9'],\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2fe303ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(9, 30)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#生成字典\n",
    "zidian = {}\n",
    "for doc in docs:\n",
    "    for word in doc:\n",
    "        if word not in zidian:\n",
    "            zidian[word] = len(zidian)\n",
    "\n",
    "zidian['0'], len(zidian)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "44f6507a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "113"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import torch.nn as nn\n",
    "\n",
    "\n",
    "#定义数据\n",
    "class DocDataset(Dataset):\n",
    "    def __init__(self):\n",
    "        xs = []\n",
    "        ys = []\n",
    "\n",
    "        for doc in docs:\n",
    "\n",
    "            #句子根据字典编码成数字\n",
    "            doc_encode = [zidian[word] for word in doc]\n",
    "\n",
    "            #每个句子从头遍历到倒数第5个字母\n",
    "            for i in range(0, len(doc) - 4):\n",
    "\n",
    "                #x是5个字母的前后各两个\n",
    "                xs.append([\n",
    "                    doc_encode[i + 0], doc_encode[i + 1], doc_encode[i + 3],\n",
    "                    doc_encode[i + 4]\n",
    "                ])\n",
    "\n",
    "                #y是5个字母中间的那个\n",
    "                ys.append(doc_encode[i + 2])\n",
    "\n",
    "        self.xs = torch.LongTensor(xs)\n",
    "        self.ys = torch.LongTensor(ys)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        return self.xs[i], self.ys[i]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.xs)\n",
    "\n",
    "\n",
    "len(DocDataset())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "69a1a714",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[19, 21, 16, 18],\n",
       "         [13, 18, 19, 20],\n",
       "         [22, 10, 12, 10],\n",
       "         [14, 15, 16, 17]]),\n",
       " torch.Size([4, 4]),\n",
       " tensor([20, 16, 17,  6]),\n",
       " torch.Size([4]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#数据加载器\n",
    "def get_dataloader():\n",
    "    dataloader = DataLoader(dataset=DocDataset(),\n",
    "                            batch_size=4,\n",
    "                            shuffle=True,\n",
    "                            drop_last=True)\n",
    "    return dataloader\n",
    "\n",
    "\n",
    "for i, data in enumerate(get_dataloader()):\n",
    "    sample = data\n",
    "    break\n",
    "\n",
    "sample[0], sample[0].shape, sample[1], sample[1].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5e258165",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[ 6.5888e-01,  4.7118e-01,  3.9229e-01, -1.9138e-01,  7.6326e-03,\n",
       "          -3.5269e-01, -4.4513e-01,  3.8955e-02,  5.0908e-01,  5.1117e-01,\n",
       "           5.5449e-01,  6.1092e-01,  6.8060e-01,  5.9301e-01,  4.1114e-02,\n",
       "           5.6111e-01, -2.9875e-01,  6.0963e-01,  7.2298e-02,  3.1759e-02,\n",
       "           3.4617e-01, -2.6114e-01,  3.6275e-01, -5.3237e-01,  4.9234e-01,\n",
       "          -2.0311e-01, -5.6541e-02, -2.0027e-01, -1.5255e-01, -5.4809e-01],\n",
       "         [ 6.5247e-01,  4.8950e-01,  3.7930e-01, -1.9945e-01,  1.2645e-02,\n",
       "          -3.3973e-01, -4.3826e-01,  5.5761e-02,  4.9318e-01,  5.0533e-01,\n",
       "           5.6777e-01,  6.0805e-01,  6.7974e-01,  5.7527e-01,  5.2539e-02,\n",
       "           5.6405e-01, -2.9518e-01,  6.3456e-01,  5.6624e-02,  4.9485e-02,\n",
       "           3.5044e-01, -2.8184e-01,  3.5403e-01, -5.5476e-01,  5.0248e-01,\n",
       "          -1.9005e-01, -3.5914e-02, -2.0255e-01, -1.6804e-01, -5.6836e-01],\n",
       "         [ 6.6266e-01,  4.8247e-01,  4.2279e-01, -2.4316e-01, -1.2265e-02,\n",
       "          -3.3076e-01, -3.9279e-01,  3.2513e-02,  4.7192e-01,  4.7845e-01,\n",
       "           6.0295e-01,  6.0709e-01,  6.7979e-01,  5.6652e-01,  3.8781e-02,\n",
       "           5.2995e-01, -2.5546e-01,  5.9928e-01,  9.1362e-02,  3.8809e-02,\n",
       "           3.3528e-01, -2.4955e-01,  3.5951e-01, -5.1847e-01,  4.7969e-01,\n",
       "          -2.3112e-01, -7.2104e-02, -2.1231e-01, -1.5240e-01, -5.5793e-01],\n",
       "         [ 6.4164e-01,  5.0985e-01,  3.4641e-01, -1.8595e-01,  2.9236e-02,\n",
       "          -3.3201e-01, -4.5370e-01,  8.2480e-02,  4.8863e-01,  5.1281e-01,\n",
       "           5.6321e-01,  6.0585e-01,  6.7892e-01,  5.6302e-01,  6.9724e-02,\n",
       "           5.8314e-01, -3.1094e-01,  6.7459e-01,  2.5441e-02,  7.1028e-02,\n",
       "           3.6166e-01, -3.1652e-01,  3.4332e-01, -5.9291e-01,  5.2281e-01,\n",
       "          -1.5825e-01,  5.5336e-04, -1.9997e-01, -1.8989e-01, -5.9215e-01]],\n",
       "        grad_fn=<AddmmBackward>), torch.Size([4, 30]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class CBOW(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.embed = torch.nn.Embedding(30, 2)\n",
    "        self.embed.weight.data.normal_(0, 0.1)\n",
    "\n",
    "        self.fc = torch.nn.Linear(2, 30)\n",
    "\n",
    "    def forward(self, x):\n",
    "\n",
    "        #[b,4] -> [b,4,2]\n",
    "        x = self.embed(x)\n",
    "\n",
    "        #[b,4,2] -> [b,2]\n",
    "        x = torch.mean(x, dim=1)\n",
    "\n",
    "        #[b,2] -> [b,30]\n",
    "        x = self.fc(x)\n",
    "\n",
    "        return x\n",
    "\n",
    "\n",
    "model = CBOW()\n",
    "\n",
    "out = model(sample[0])\n",
    "\n",
    "out, out.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4c12d416",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 tensor(3.4605, grad_fn=<NllLossBackward>)\n",
      "200 tensor(3.1750, grad_fn=<NllLossBackward>)\n",
      "400 tensor(3.4084, grad_fn=<NllLossBackward>)\n",
      "600 tensor(3.3346, grad_fn=<NllLossBackward>)\n",
      "800 tensor(2.3855, grad_fn=<NllLossBackward>)\n"
     ]
    }
   ],
   "source": [
    "criteon = torch.nn.CrossEntropyLoss()\n",
    "optim = torch.optim.SGD(model.parameters(), lr=1e-2)\n",
    "\n",
    "model.train()\n",
    "for epoch in range(1000):\n",
    "    for i, data in enumerate(get_dataloader()):\n",
    "        x, y = data\n",
    "        optim.zero_grad()\n",
    "\n",
    "        #计算\n",
    "        #[b,4] -> [b,30]\n",
    "        out = model(x)\n",
    "\n",
    "        loss = criteon(out, y)\n",
    "        loss.backward()\n",
    "        optim.step()\n",
    "\n",
    "    if epoch % 200 == 0:\n",
    "        print(epoch, loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3718d127",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXIAAAD5CAYAAAA6JL6mAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAcOElEQVR4nO3de5RU1Zn+8e9LA30DVKC9gkKAIWEweOmAoiYZZAwQEy9BIipq4lrIqDNqolHj+NMko9HFZMZMjJlBjNGIRpdKvJGoeBkGxUujiCLIEDMqRKUR5NbQdHe9vz92k75V003X6Tp1qp7PWrXoOl11zlvdzVO79tlnb3N3REQkuXrEXYCIiGRGQS4iknAKchGRhFOQi4gknIJcRCThFOQiIgnXM9MdmFkJsAgobtzfQ+5+/Z6eM3DgQB8yZEimhxYRKShLly7d4O4VrbdnHORALTDB3beZWS9gsZn9wd1fbu8JQ4YMoaqqKoJDi4gUDjN7P932jIPcwxVF2xrv9mq86SojEZEsiaSP3MyKzGwZsB54xt1fiWK/IiLSsUiC3N0b3P0IYBAw1sxGt36Mmc00syozq6quro7isCIiQsSjVtz9M+AFYFKa781x90p3r6yoaNNXLyIiXZRxkJtZhZnt2/h1KTARWJXpfltzh+XL4cUXYefOqPcuIpJcUYxaOQi428yKCG8MD7r7ExHs96/+9CeYMgXWrYOiIkil4Fe/gnPOifIoIiLJFMWoleXAkRHUklYqBSeeCB98EFrlu114IRx+OIwZ011HFhFJhpy/svOll2DjxpYhDlBbC7ffHk9NIiK5JOeDfMMGMGu7vaEB/vKX7NcjIpJrcj7Ix4+HXbvabi8vh298I/v1iIjkmpwP8v33h6uuCsG9W2kpDBkCM2bEVpaISM6IYtRKt7vhBhg7Fm67DTZtgqlTYdasEOgiIoUuEUEOYfjhlClxVyEikntyvmtFRET2TEEuIpJwCnIRkYRTkIuIJJyCXEQk4RTkIiIJpyAXEUk4BbmISMIpyEVEEk5BLiKScApyEZGES8xcK5K7XnsNHnggfH3mmVBZGW89IoVGQS4Z+eEP4ec/hx07wgIgv/oVXHYZ3Hhj3JWJFA51rUiXrVgBt94KNTVhKb5UKnz97/8OK1fGXZ1I4VCQS5c9/jjU1bXdXl8Pjz2W/XpECpWCXLqsuBiKitpuLyoK3xOR7FCQS5dNnQo92vkLmjo1u7WIFDIFuXTZ4MHh5GZJSVhTtU+f8PV//RcMGhR3dSKFQ6NWJCPnnReW4HvyyXD/5JNh4MB4axIpNApyyVhFBZx/ftxViBSujLtWzGywmT1vZivNbIWZXRpFYSIi0jlRtMjrge+7++tm1hdYambPuPs7EexbREQ6kHGL3N0/cvfXG7/eCqwEDsl0vyIi0jmRjloxsyHAkcArUe5XRETaF1mQm1kf4GHgMnffkub7M82sysyqqqurozqsiEjBiyTIzawXIcTnufsj6R7j7nPcvdLdKysqKqI4rDSqqwvD/+65B/7857irEZFsy/hkp5kZcCew0t3/LfOSZG+sWAETJoTZB1MpaGiAmTPDZFZmcVcnItkQRYv8OGAGMMHMljXepkSwX+mAe7gAZ/162LoVtm+HnTvhzjvh97+PuzoRyZaMW+TuvhhQ2y8Gy5bBhg1tt2/fHi6dP+20rJckIjHQXCsJtmNH+5NWbd+e3VpEJD4K8gSrrEzfD15WBmedlf16RCQeCvIE690bfvObENy9eoVt5eUwejRccEGspYlIFmnSrIQ79VR4802YOxc+/hgmT4bTT28KdhHJfwryPDB8ONx8c3aO9eKLYWHl//1fGDcOrrsORo7MzrFFJD0FuXTa738PZ58dFliGcPHRo4/CSy/B4YfHWppIQVMfuXSKO1xySVOIQ7j4aPt2uOqq+OoSEQW5dNKmTZBuihx3WLIk+/VI+CQ0cSIceGC4unfx4rgrkrioa0U6pU+f9ses5+LUOdu2wcMPhxPAxx0Xbvk0ZcGzz8I3v9n0CemTT+BrX4NHHgn/SmFRkEun9O4dlnO7++5wIdJuZWVw9dWxlZXWG2+EFmp9fZiyoLgYTjgBHnssf0bzXH55y24uCPcvvxze0ZIuBUddK9Jpt94KZ5wBJSXQt28I8R/8AL7znbgra+IO3/oWfPZZaJXX14d+/EWLYM6cuKuLTnthvWpV+BlIYVGQS6cVF4cW+bp1YRji+vVw/fW51WXx7ruhrtZqasJkYvli4MD02wcMyK3fh2SHglz2Wv/+YbhheXnclbS1p9ZoPrVUr746fCJqbvcnJCk86iOXvPL5z4dWaetJw8rKcqsLKFOXXgqbN8Ps2U3bLr8crrgivpokPuYxNFMqKyu9qqoq68eVwvDqq2FYXkND6FIpLw8TjD39dDhpm09qa8PInAMOCOcuJL+Z2VJ3r2y9XS1yyTtjx8L778MDD4SQO/54OPHE/Ow7Li6Gww6LuwqJm4Jc8tJ++8GsWXFXIZIdOtkpIpJwCnJJq6EhdE1MmRKuIHz00fwa9SGST9S1Im24w9Sp8MwzTaM/nnsOpk+HO+6ItzYRaUstcmlj0aKWIQ7h63nz4K234qtLRNJTkEsbTz2VfvHmhgZYuDD79YjIninIpY3+/cOwttZ69QqjQUQktyjIpY2zzko/Za1ZWA9URHKLglzaOPhgeOgh6Nev6da/PyxYEL4WkdyiUSuS1pQpYRbBxYuhZ08YPz5/5vIWyTeRBLmZ/Ro4GVjv7qOj2KfEr7g4XNouIrktqq6V3wCTItqXiIjshUiC3N0XARuj2JeIiOydrJ3sNLOZZlZlZlXV6ZZjFxGRLslakLv7HHevdPfKilxcdl1EJKE0/FBEJOEU5FLQPvgALroIRo+GU06BJUvirkhk70U1/PB+4KvAQDNbC1zv7nm0Zrnkoz/9CY4+OswrU18PK1aEuWTuvjvM/iiSFJEEubtPj2I/kp8++yysl1lUBF/7GvTpE3dFwXXXwdatkEo1baupgYsvDlMRpJumQCQX6crOTkilQhAtWwbDhoWFFtJNKiVt3XsvzJwZrg6F8LN84AH4+tfjrQvg+edbhvhuW7fCunUweHD2axLpCgV5B7ZuhRNOCB/Dd+6E0lLo2zf0pR56aNzV5bb/+78Q4jt2tNx+xhnw4YcwYEDn9vPpp+H3cOih0baS998/LM7cWioF++4b3XFEups+PHbguutg1SrYti30o27dCp98AuefH3dlue/++8PPrLUePWD+/I6f/+mnMGlSmMRr1CgYNAj+8Ifo6vvBD6C8vOW24mI47bTwZt3cli1wwQVQVhbmnJk8Gd57L7paRDKhIO/AvHlQW9tyW0MD/M//hP5Uad/uN7/WGhrSL1zR2uTJYYm5XbtCq/6jj8JJyBUroqnvrLPgyivDp6x+/aCkBE46CebObfk49zDnzL33hjrq60NX27hxof9fJG4K8g7sacFhLUa8Z9/4RgjJdKZM2fNz3347BHZdXcvttbVw662RlIcZXH99+IT17LOh++yxx9q20l96KXwq27WraVsqFd7I77knmlpEMqEg78C3vw29e7fc1qMHHHts2//w0tK4cXDmmU0/J7PQNXHZZTBixJ6f++GH6afNbWgIgduRjRvhzjvh5z8PIbwnfftCZWXowkln1ar0J0VrasIJcJG46WRnB268MXy8X7s2dBX06RPC6K674q4sWrW1Iazaa0F3hVnoppg+He67LwTzjBlw/PEdP3f06PDzbq2kBCZM2PNzn3oqDB80C90g11wDs2bBz34Wtu2tL3wh/fPKyuCII/Z+fyJRM4+hf6CystKrqqqyftyuqq+Hxx8Pra/PfS6Muigri7uqaHz0UTiJt3BhCPJjj4Vf/7rjFnN3O//80Cfd0NBy+8CBoYXcvz88+CDMnh26RiZOhBtugIoKOOCAtm8C5eWh26SjN4F03MOnizffbOpe6dEj1LBmDeyzT1deocjeM7Ol7l7Z5hvunvXb0Ucf7RK/+nr3oUPde/Z0D3HlbuY+YID7li3x1fXxx+7FxU01Nb9Nnx4e8+Mfu5eXN20vKnLfbz/3u+5y79cv/XPPPbfrNW3e7H7BBe6lpeHnNXmy+3vvRfJyRToNqPI0mao+8gL2xz/Chg0tR5a4h/Hy990XX11r1oQulHTefTcMAf3pT1uOfGloCK3whx9uf7/pRtB0Vr9+oZuopiacgF2wAIYO7fr+RKKkIC9ga9a0HImx2/btHZ8g7E7DhoU3k9aKikKf9MqV6U+E1tWFi5DSBXZ5OZx9dtSViuQGBXkBGzMmfSD26RMmk4rLgQeG0UKtT7yWlISLeA46KP0bkFl4E7jrrvDc3r3DtvLycPJz8uTs1C+SbRq1UsC+8pUwImP58qaLnnr1CicM4579b+7ccCXnL38ZulIqK+EXv4CRI8P3v/rVMFdK84u1SktD0I8fD8ccE64s3bIFTj453O/KiBWRJNColQK3bRv88z+HESL19aHlesstIcxz2ZYtYWTLggVhQq7SUrjtttCSF8lX7Y1aUZBLom3aFG6HHRb60EXyWXtBrj7yhHr++dCF0L9/6DZYuDDuiuKx335hbL9CXAqZgjyB/vjH0O+7ZElojb7ySlim7Ikn4q5MROKgIE+gK65oO/NiTQ1873vx1CMi8VKQJ1B7Y7zXrNGMjCKFSEGeQAcckH57RYWG2IkUIgV5Al17bdtJu8rK4Ic/zHzfdXXpp2wVkdylIE+gf/iHsCDCPvuEqx379Qtjwf/pn7q+zzfegLFjw/7KysIY7a1bIytZRLqRxpEnWH19WEChf/+mVeq7Yu3asCZm8+AuLg7DGl94IeMyRSQiGkeeh3r2DCvBZxLiEC6Db70uaW0tvPZaWHJNRHKbglx46630k1D17AmrV2e/HhHZOwpy+WvfeGt1dWHJNRHJbZEEuZlNMrN3zWyNmV0dxT4le2bNCic4ezT7aygthb//e/ibv4mvLhHpnIyD3MyKgF8Ck4FRwHQzG5XpfiV79t8/XOb/9a+HAB8wIKx0/+CDcVcmIp0RxXzkY4E17v4egJn9DjgFeCeCfUuWDB8eFicWkeSJomvlEODDZvfXNm5rwcxmmlmVmVVVV1dHcFgREYFogjzdReFtBqe7+xx3r3T3yopcX7VARCRBogjytcDgZvcHAX+JYL8iItIJUfSRvwaMMLOhwDrgTOCsCPYrhW716jDJenFxWIPuoIPirkgkJ2XcInf3euAS4ClgJfCgu6/IdL9S4G64AY44IswEduWVMGxYWE1ZRNrQXCuSe15/HU44oe3qGSUlYWKYAQPiqUskZpprRZLj/vth586223v21Hp2ImkoyCX3pFLplzpy12TpImkoyCX3TJsWLjFtraEhXH4qIi0oyKXLNm2CG2+EL38ZzjkHIjvtMW5cWD2jrAyKiqB37xDst90W5hMQkRZ0slO6ZMMGOPLI8O/OnWHCrZISmDsXpk+P6CBvvgmPPhp2PG0aDBkS0Y5Fkqm9k51RjCOXAjR7Nqxf3zSPeSoVBplcdBFMnQq9ekVwkDFjwk1E9qhgulZ0jixajz+efjGKhgZYuTL79YgUsrwP8rlz4eCDQ1froYfCvffGXVF+aG8od10d7LdfdmtJhM62JOrr4aab4MADwzmCk06CFbq+TvYsr4N87ly49FL46KNw/8MP4cIL4YEH4q0rH1x+OZSXt9zWsyccdRQMHpz+OQXHPfRBDRwYWhIjRsCTT+75OTNnhjPIn3wCO3bAwoVw7LHw/vvZqVkSKa+D/Lrr2l4cWFMD114bTz355LTT4IorwnnIffYJjcfDD4dHHom7shzyox+FqQY+/TTcX7MGzjgDnn8+/eM//hjuu6/lH617OJv8r//a7eVKcuXtqJWGhvZXl+/ZM3QBSOY2boSlS8N8Vlrfs5ldu0L/07Ztbb933HGweHHb7S+8AKeeCps3t/3euHHw8stRVykJU3CjVoqK4JBDYN26tt8bNiz79eSr/v3D2p7SSnV1aE2k8+676bcPGwa1tW23FxXpXVL2KK+7Vm66KXzkb66sDG6+OZ56pIBUVLT/kfBv/zb99sGDYfLk0F/VXHFx6McSaUdeB/m558Idd8DQoeH/1IgRYdTKqafGXZnslbo6uP76cFVnWVm4TH/16rir2rPeveGaa9q2JEpL4V/+pf3n3XcffPe74XE9eoQTD08/DZ//fPfWK4mWt33kkkemTQuzHu7YEe6bhTOs77yT24tNuMPtt8NPfxpGoYwaBT/7GUyc2PFzU6nwBlZc3P11SmJoGltJpj//OVx9tDvEIQTkjh3wi1/EV1dnmMHFF4c51OvqwpQDnQlxCK1xhbh0koJccts776QPtNpaePXV7NcjkoMU5JLbhg9PPxdAr16h/1hEFOSS40aOhOOPTz+S49JL46lJJMcoyCX3zZ8fJjwvLg59x0cdBc89p2ltRRpp1IokRyoVJpXq3TvuSiTbdi/zV1QUdyWx0qgVSb4ePRTihWb79jCRWFlZ+N0fcwwsWxZ3VTlHQS4iueuUU+C3vw0Th6VS8MorYW3BtWvjriynKMhFJDetXAkvvRRCvLna2rB+q/yVglxEctPq1enXDNy1K1xcJX+VUZCb2RlmtsLMUmbWpgNeRKTLRo1KP990cTF86UvZryeHZdoifxs4HVgUQS0iIk1GjAhzJJeWNm0zC/cvuii+unJQRkHu7ivdvZ3JlUVEMvTgg/CP/xgWgi0uhkmTwgnPAw+Mu7KckrcLS4hIHiguhltuCTdpV4dBbmYLgXRvf9e6+6OdPZCZzQRmAhx66KGdLlBERPaswyB3907Ou9nhfuYAcyBc2RnFPkVERMMPRUQSL9Phh6eZ2VrgWOBJM3sqmrJERKSzMjrZ6e7zgfkR1SIiIl2grhURkYRTkIuIJJyCXEQk4RTkIiIJpyAXEUk4BbmISMIpyEVEEk5BLiKScApyEZGEU5CLiCScglxEJOEU5CIiCacgFxFJOAW5iEjCKchFRBJOQS4iknAKchGRhFOQi4gknIJcRCThFOQiIgmX0eLLIh167TV49VUYNAimTIFeveKuSCTvKMile+zaBaeeCv/935BKhQDv0wcWL4bPfS7u6kSya+tWuPdeeP11OPxwOPdc2HffyHavIJfu8R//EUK8pibc37kTtm+Hb387tNJFCsXatfClL8GWLeH/Q1kZ/PjHsGQJjBgRySHURy7d4447mkJ8t1QK3noLPv44nppE4nDZZVBd3fT/oaYGNm2CCy+M7BAK8jhs3w7vvw91dXFX0n3ae209euT36xZpbcECaGhouS2VgkWL2m7vIgV5NtXVwUUXQUUFjBoV/r399rir6h7Tp0NJSdvtgweHE58ihaK9E/xFRWAWySEyCnIzm21mq8xsuZnNN7N9I6kqX33ve3D33bBjR/h4tXkzXHklPPxw3JVF7+qrYfjwcIIToLQU+vaFefMi++MVSYRzzoHi4pbbeveG008Pn1AjYO7e9SebnQQ85+71ZnYLgLtf1dHzKisrvaqqqsvHTaSdO6F//xDirY0ZA8uWZb2kbldXB/Pnh5EqQ4fCjBkwcGDcVYlk17ZtMHEivP02uIfwHjoUXnghZMJeMLOl7l7ZentGo1bc/elmd18Gpmayv7y2aVP4Jaazbl12a8mWXr1g2rRwEylUffqEESpLloQwHzkSvvzlSD+ZRjn88LvAAxHuL7/sv3/4he7c2XK7GYwdG09NIpIdZjB+fLh1gw47aMxsoZm9neZ2SrPHXAvUA/P2sJ+ZZlZlZlXV1dXRVJ8kRUUwe3YYQ7qbWbh/003x1SUiiZdRHzmAmZ0HzAJOdPeajh4PBdpHvtuCBeFigA8+CBcJ/OQn8MUvxl2ViCRAt/SRm9kk4CrgK50N8YI3ZUq4iYhEJNOxL7cBfYFnzGyZmf1nBDWJiMheyHTUyvCoChERka7RlZ0iIgmnIBcRSTgFuYhIwinIRUQSTkEuIpJwCnIRkYRTkIuIJJyCXEQk4RTkIiIJpyAXEUk4BbmISMIpyPOJe/urEIlI3lKQ54P33oOTTgpLq5WWwnnnhYWdRaQgRLnUm8Rh82YYNw42boRUChoa4He/C2sDVlVpxXqRAqAWedLdcw/U1IQQ323XLli9Gl58Mb66RCRrFORJ98YbIchbS6Vg5crs1yMiWacgT7qjjmq5oPNuZjBqVPbrEZGsU5An3bnnQnk59Gj2qywuDiE+fnx8dYlI1ijIk65fP3j11bCgc+/eIdRnzIBnn9WJTpECoVEr+WDIEHj88birEJGYqEUuIpJwapFL/qqvD59Unn4aDjgAvvMdOOywuKsSiZyCXPJTbS1MmADLl8O2beH8wezZ8NBDMHly3NWJREpdK5Kf7rwTli0LIQ7hIqmaGjj7bKiri7U0kagpyCU//fa36S+UamiApUuzX49IN1KQS34qLU2/PZWCkpLs1iLSzTIKcjP7iZktN7NlZva0mR0cVWEiGbnwwjCmvrUBA2DMmOzXI9KNMm2Rz3b3L7r7EcATwP/LvCSRCEybBmedFVrmZWXQt28I8cce04VSkncyGrXi7lua3S0HtKqB5AYzmDMHvv99WLQIBg4MV78WF8ddmUjkMh5+aGY3AucCm4G/y7gikSiNHBluInmsw64VM1toZm+nuZ0C4O7XuvtgYB5wyR72M9PMqsysqrq6OrpXICJS4MwjWuPRzA4DnnT30R09trKy0quqqiI5rohIoTCzpe5e2Xp7pqNWRjS7+01gVSb7ExGRvZdpH/nNZjYSSAHvA7MyL0lERPZGpqNWvhVVISIi0jWR9ZHv1UHNqgkt+GwbCGyI4bjdRa8nt+n15LYkvp7D3L2i9cZYgjwuZlaV7kRBUun15Da9ntyWT69Hc62IiCScglxEJOEKLcjnxF1AxPR6cpteT27Lm9dTUH3kIiL5qNBa5CIieafggtzMZpvZqsZ51Oeb2b5x15QJMzvDzFaYWcrMEnkG3swmmdm7ZrbGzK6Ou55MmdmvzWy9mb0ddy1RMLPBZva8ma1s/Fu7NO6aMmFmJWb2qpm92fh6fhR3TZkquCAHngFGu/sXgdXANTHXk6m3gdOBRXEX0hVmVgT8EpgMjAKmm9moeKvK2G+ASXEXEaF64Pvu/gXgGODihP+OaoEJ7j4GOAKYZGbHxFtSZgouyN39aXevb7z7MjAoznoy5e4r3f3duOvIwFhgjbu/5+67gN8Bp8RcU0bcfRGwMe46ouLuH7n7641fbwVWAofEW1XXedC4Kje9Gm+JPllYcEHeyneBP8RdRIE7BPiw2f21JDgk8p2ZDQGOBF6JuZSMmFmRmS0D1gPPuHuiX0/GC0vkIjNbCByY5lvXuvujjY+5lvCRcV42a+uKzryeBEu37lqiW0f5ysz6AA8Dl7VaHSxx3L0BOKLxHNl8Mxvt7ok9p5GXQe7uE/f0fTM7DzgZONETMP6yo9eTcGuBwc3uDwL+ElMt0g4z60UI8Xnu/kjc9UTF3T8zsxcI5zQSG+QF17ViZpOAq4BvuntN3PUIrwEjzGyomfUGzgQei7kmacbMDLgTWOnu/xZ3PZkys4rdo9XMrBSYSMLXUii4IAduA/oCz5jZMjP7z7gLyoSZnWZma4FjgSfN7Km4a9objSeeLwGeIpxEe9DdV8RbVWbM7H5gCTDSzNaa2QVx15Sh44AZwITG/zPLzGxK3EVl4CDgeTNbTmhIPOPuT8RcU0Z0ZaeISMIVYotcRCSvKMhFRBJOQS4iknAKchGRhFOQi4gknIJcRCThFOQiIgmnIBcRSbj/D50zQy/swQ1yAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "colors = []\n",
    "idxs = []\n",
    "for word, idx in zidian.items():\n",
    "    idxs.append(idx)\n",
    "    if word in '1234567890':\n",
    "        colors.append('red')\n",
    "        continue\n",
    "    colors.append('blue')\n",
    "\n",
    "#[30] -> [30,2]\n",
    "embed = model.embed(torch.LongTensor(idxs)).detach().numpy()\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.scatter(embed[:, 0], embed[:, 1], c=colors)\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
